from dataclasses import dataclass, field
from typing import NamedTuple, Optional

import numpy as np
import torch


class NoiseParams(NamedTuple):
    """Noise settings carried by an energy backend (see ``EnergyBackend.noise_params``).

    Every field defaults to ``0.0``. NOTE(review): the field meanings below
    are inferred from the names only; confirm against the concrete backends.
    """

    # presumably an overall noise strength
    level: float = 0.0
    # presumably the probability of reading out a 0 as a 1
    prob_0to1: float = 0.0
    # presumably the probability of reading out a 1 as a 0
    prob_1to0: float = 0.0


# Registry mapping a backend name to the object registered under it.
BACKENDS = dict()


def register_backend(name):
    """Create a decorator that records its target in ``BACKENDS`` under `name`.

    The decorated object is returned unchanged, so the decorator works for
    functions and classes alike. Registering another object under an existing
    name replaces the earlier entry.

    Parameters
    ----------
    name: hashable
        Key under which the decorated object is stored in ``BACKENDS``.

    Returns
    -------
    callable
        A decorator that registers its argument and returns it as-is.
    """
    def decorator(target):
        BACKENDS.update({name: target})
        return target

    return decorator


class MeasureEnergy(torch.autograd.Function):
    """Autograd bridge between torch and an energy backend.

    The forward pass returns the mean energy and its variance as measured by
    ``backend.measure_energy``. The backward pass propagates the gradient of
    the mean energy only, estimated with ``backend.parameter_shift_gradient``
    and averaged over ``backend.n_reps`` repetitions. Any gradient flowing
    into the variance output is ignored.
    """

    @staticmethod
    def forward(ctx, angles, backend, n_readout):
        """Measure the mean energy and its variance for `angles`.

        Parameters
        ----------
        angles: :py:obj:`torch.Tensor`
            The gate parameters. In the backward pass the first dimension is
            repeated ``backend.n_reps`` times.
        backend:
            Object providing ``measure_energy``, ``parameter_shift_gradient``
            and an ``n_reps`` attribute.
        n_readout: int
            Number of shots (readouts) used per measurement.

        Returns
        -------
        mean, var: :py:obj:`torch.Tensor`
            The measured mean energies and their variances.
        """
        ctx.save_for_backward(angles)
        ctx.backend = backend
        ctx.n_readout = n_readout
        mean, var = backend.measure_energy(angles, n_readout)
        # as_tensor shares memory with NumPy arrays (like from_numpy) but also
        # accepts tensors and Python scalars returned by a backend.
        return torch.as_tensor(mean), torch.as_tensor(var)

    @staticmethod
    def backward(ctx, grad_output_mean, grad_output_var):
        """Return the gradient wrt. `angles`; `backend` and `n_readout` get ``None``.

        `grad_output_var` is ignored: the variance is not differentiated.
        """
        angles, = ctx.saved_tensors
        shape = angles.shape
        n_reps = ctx.backend.n_reps

        # The gradient is measured n_reps times to reduce its variance.
        # repeat_interleave puts the copies of each row next to each other:
        # row ``i * n_reps + r`` is copy ``r`` of row ``i``.
        repeated = angles.repeat_interleave(n_reps, 0)
        grad_input = ctx.backend.parameter_shift_gradient(repeated, ctx.n_readout)
        # Backends may return NumPy arrays; autograd requires a tensor.
        grad_input = torch.as_tensor(grad_input, dtype=angles.dtype, device=angles.device)
        # Undo the interleaving by grouping each row's copies on a new axis 1,
        # then average them. (Reshaping to (n_reps, *shape) would mix rows.)
        grad_input = grad_input.reshape(shape[0], n_reps, *shape[1:]).mean(1)

        # Broadcast the upstream gradient over the trailing (parameter) dims.
        grad_output = grad_output_mean[(...,) + (None,) * (grad_input.dim() - 1)]
        return grad_output * grad_input, None, None


@dataclass
class EnergyBackend:
    """Abstract class defining the interface of energy functions (backends).

    Concrete backends override :py:meth:`measure_energy` and
    :py:meth:`parameter_shift_gradient`, and :py:meth:`measure_overlap` where
    supported. The base implementations raise :py:exc:`NotImplementedError`.
    Calling an instance measures the energy through :py:class:`MeasureEnergy`,
    so for :py:obj:`torch.Tensor` input the mean energy can be
    back-propagated to the angles.
    """
    # NOTE(review): the meanings of most fields are inferred from their names
    # and defaults only; confirm against the concrete backends.
    n_qbits: int = 5  # presumably the number of qubits
    n_layers: int = 3  # presumably the number of circuit layers
    j: tuple = (-1., 0., 0.)  # presumably coupling constants, one per component
    h: tuple = (0., 0., -1.)  # presumably external-field components
    circuit: str = 'esu2'  # presumably the name of the circuit ansatz
    mom_sector: int = 1  # presumably the momentum sector
    noise_params: NoiseParams = field(default_factory=NoiseParams)
    # Number of repeated gradient measurements averaged in MeasureEnergy.backward.
    n_reps: int = 1
    pbc: bool = False  # presumably periodic boundary conditions
    rng: Optional[np.random.Generator] = None  # defaults to None

    def __call__(self, angles, n_readout):
        """Measure the mean energy and its variance for `angles`.

        Parameters
        ----------
        angles: :py:obj:`np.ndarray` or :py:obj:`torch.Tensor`
            The gate parameters. NumPy input is converted to a tensor first.
        n_readout: int
            Number of shots (readouts) used to measure the hamiltonian.

        Returns
        -------
        energies, variances: tuple
            Tensors for tensor input, where the energies are differentiable
            wrt. `angles` through :py:class:`MeasureEnergy`. For NumPy input,
            both values are returned as :py:obj:`np.ndarray`.
        """
        is_numpy = isinstance(angles, np.ndarray)
        if is_numpy:
            angles = torch.from_numpy(angles)
        result = MeasureEnergy.apply(angles, self, n_readout)

        if is_numpy:
            result = tuple(elem.detach().cpu().numpy() for elem in result)
        return result

    def measure_energy(self, angles, n_readout):
        """Measure the expected energy and variance of the hamiltonian.

        Parameters
        ----------
        angles: :py:obj:`torch.Tensor`
            The gate parameters. When invoked through :py:meth:`__call__`,
            NumPy input has already been converted to a tensor.
        n_readout: int
            Number of shots (readouts) used to measure the hamiltonian.

        Returns
        -------
        energies: :py:obj:`np.ndarray`
            The mean energy values over the shots.
        variances: :py:obj:`np.ndarray`
            The variance of the energy values over the shots.

        """
        raise NotImplementedError

    def measure_overlap(self, angles, exact_wf):
        """Compute the overlap between `exact_wf` and the state prepared by the circuit with `angles`.

        Parameters
        ----------
        angles: array-like
            The gate parameters of the circuit.
        exact_wf: :py:obj:`numpy.ndarray`
            State vector as a complex numpy array.

        Returns
        -------
        overlap
            The overlap between both states; its exact form is defined by the
            concrete backend.
        """
        raise NotImplementedError

    def parameter_shift_gradient(self, angles, n_readout):
        """Compute the gradient of the mean energy wrt. the gate parameters `angles` using the parameter shift rule.

        Parameters
        ----------
        angles: :py:obj:`torch.Tensor`
            The gate parameters. When invoked from ``MeasureEnergy.backward``,
            each entry along the first dimension is repeated ``n_reps`` times
            consecutively.
        n_readout: int
            Number of shots (readouts) used to measure the hamiltonian.

        Returns
        -------
        gradient: :py:obj:`np.ndarray` or :py:obj:`torch.Tensor`
            The gradient of the mean energy values over the shots wrt. the
            gate parameters `angles`, with the same shape as `angles`.

        Note
        ----
        A future release may compute the variance also, in case it is needed.
        The variance of the parameter shift requires covariances of the individual shots, which is not possible to get
        through the current way of the energy estimation.

        """
        raise NotImplementedError
